#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import numpy as np

from scipy.signal import convolve2d as c2d
from scipy.sparse import csc_matrix, lil_matrix
from scipy.special import expit

# Default numeric dtypes for arrays allocated by this module
# (FLOAT_TYPE is the dtype of every buffer the layer functions create).
FLOAT_TYPE, INT_TYPE, U_INT_TYPE = np.float32, np.int32, np.uint32

def softmax(v):
    """Return the softmax of *v*, clipped to [1e-8, 1 - 1e-8].

    Normalisation is over *all* elements of *v* (a single global sum, not
    per-row).  The maximum is subtracted before exponentiating: this leaves
    the result mathematically unchanged but stops ``np.exp`` from overflowing
    to ``inf`` (which would make the output NaN) for large inputs.

    Args:
        v: scalar, sequence, or array of scores.

    Returns:
        Array of the same shape as *v* with probabilities clipped away from
        exactly 0 and 1.
    """
    v = np.asarray(v)
    if v.dtype.kind in 'biu':
        # Integer/bool input: work in float64 (what np.exp would return anyway)
        # so that subtracting the max cannot wrap around for unsigned types.
        v = v.astype(np.float64)
    if v.size:
        # Shift so the largest exponent is exp(0) == 1 -> no overflow.
        v = v - np.max(v)
    e = np.exp(v)
    return np.clip(e / np.sum(e), 0.00000001, 0.99999999)

def sigmoid(v):
    """Element-wise logistic function ``1 / (1 + exp(-v))``.

    Delegates to ``scipy.special.expit``, which handles large |v| without
    overflow warnings.
    """
    activated = expit(v)
    return activated

def tanh(v):
    """Return ``np.tanh(v / 2)`` element-wise.

    NOTE(review): the argument is halved, so this equals ``2 * sigmoid(v) - 1``
    rather than the ordinary hyperbolic tangent of *v*; confirm this scaling
    is what callers expect.
    """
    half = v / 2.
    return np.tanh(half)

def input_layer_backward(delta, decay, W, indexes):
    """Scatter the error ``delta`` back onto the columns of a matrix shaped like *W*.

    For each position ``pos``, the vector ``delta[0, :, pos]`` is added to
    column ``indexes[pos]`` of a zero-initialised FLOAT_TYPE matrix.  Repeated
    indexes therefore accumulate.

    Args:
        delta: array indexed as ``delta[0, :, pos]`` (presumably shape
            (1, n_rows_of_W, len(indexes)) -- TODO confirm against caller).
        decay: accepted for signature compatibility; not used here.
        W: only its ``.shape`` is read.
        indexes: sequence of column indexes into W, one per position.

    Returns:
        The gradient matrix, same shape as *W*.
    """
    grad_W = np.zeros(W.shape, dtype=FLOAT_TYPE)
    for pos, column in enumerate(indexes):
        grad_W[:, column] += delta[0, :, pos]
    return grad_W

def wide_convolution_layer_forward(input, window_size, n_filters, W, b):
    """Wide (full) 2-D convolution of every input feature map with every filter.

    For filter ``f`` the output is
    ``sum_m convolve2d(input[m], rot180(W[m, f]), mode='full') + b[f]``.
    Convolving with a 180-degree rotated kernel amounts to a cross-correlation
    with ``W[m, f]``.

    Args:
        input: array indexed (map, row, col).
        window_size: kernel width; the output is ``window_size - 1`` columns
            wider than the input.
        n_filters: number of output feature maps.
        W: kernels indexed ``W[map, filter, :, :]``.  The output keeps the
            input's row count, so each kernel presumably has a single row
            (shape (1, window_size)) -- TODO confirm.
        b: per-filter bias; ``b[f]`` is added to output map ``f``.

    Returns:
        FLOAT_TYPE array of shape (n_filters, n_rows, n_cols + window_size - 1).
    """
    n_maps, n_rows, n_cols = input.shape[0], input.shape[1], input.shape[2]
    out_width = n_cols + window_size - 1
    output = np.zeros((n_filters, n_rows, out_width), dtype=FLOAT_TYPE)
    for f in range(n_filters):
        acc = output[f]  # view into `output`; in-place adds write through
        for m in range(n_maps):
            kernel = np.rot90(W[m, f, :, :], 2)
            acc += convolve2d(input[m, :, :], kernel, mode='full')
        acc += b[f]
    return output

def wide_convolution_layer_backward(input_images, W, b, grad, decay, back_linear=True):
    """Back-propagate through the wide convolution layer.

    Args:
        input_images: layer input indexed (map, row, col).
        W: kernels indexed ``W[map, filter, :, :]``.
        b: bias; only its ``.shape`` is read (to size ``grad_b``).
        grad: error w.r.t. the layer output, indexed (filter, row, col).
        decay: L2 weight-decay factor; ``decay * W`` is added to ``grad_W``.
        back_linear: when False, the input error is multiplied element-wise
            by ``x * (1 - x)`` with ``x = input_images``.  That is the sigmoid
            derivative if ``input_images`` are sigmoid outputs -- confirm.

    Returns:
        ``(back_grad, grad_W, grad_b)``: error w.r.t. ``input_images``,
        gradient w.r.t. ``W`` (decay included), and per-filter bias gradient
        ``grad_b[f] = sum(grad[f])``.

    NOTE(review): ``convolve2d_fast`` is not defined anywhere in the visible
    part of this file (only ``convolve2d`` is); confirm it is provided elsewhere.
    """
    back_grad = np.zeros(input_images.shape, dtype=FLOAT_TYPE)
    grad_W = np.empty(W.shape, dtype=FLOAT_TYPE)
    grad_b = np.empty(b.shape, dtype=FLOAT_TYPE)
    n_maps, n_filt = W.shape[0], W.shape[1]

    for m in range(n_maps):
        x = input_images[m, :, :]
        # Error w.r.t. input map m: 'valid' convolution of each filter's
        # output error with that filter's kernel, summed over filters.
        for f in range(n_filt):
            back_grad[m, :, :] += convolve2d_fast(grad[f, :, :], W[m, f, :, :], mode='valid')
        if not back_linear:
            back_grad[m, :, :] = back_grad[m, :, :] * x * (1 - x)
        # Kernel gradient: 'valid' convolution of the rotated output error
        # with the input map.
        for f in range(n_filt):
            grad_W[m, f, :, :] = convolve2d_fast(np.rot90(grad[f, :, :], 2), x, mode='valid')

    grad_W += decay * W
    for f in range(n_filt):
        grad_b[f] = np.sum(grad[f])
    return back_grad, grad_W, grad_b

def k_max_pooling_image(input_image, k, b):
    """Keep the ``k`` largest values of every row, in original order, plus a row bias.

    Args:
        input_image: 2-D array (rows, cols).
        k: number of values kept per row; must satisfy ``0 <= k <= cols``.
        b: per-row bias; ``b[j]`` is added to every kept value of row ``j``.

    Returns:
        ``(new_image, kmax_index)`` where ``new_image`` is a (rows, k)
        FLOAT_TYPE array, and ``kmax_index`` is a (rows, k) array holding the
        column indexes of the kept values, ascending within each row (so the
        original left-to-right order is preserved).

    Raises:
        ValueError: if ``k`` is negative or exceeds the number of columns.
    """
    n_rows, n_cols = input_image.shape[0], input_image.shape[1]
    if not 0 <= k <= n_cols:
        raise ValueError("k must be in [0, %d], got %r" % (n_cols, k))
    # Slice from ``n_cols - k`` instead of ``-k``: ``[:, -0:]`` would select
    # every column when k == 0 instead of none.
    top_k = np.argsort(input_image, axis=1)[:, n_cols - k:]
    kmax_index = np.sort(top_k, axis=1)
    new_image = np.empty((n_rows, k), dtype=FLOAT_TYPE)
    for j in range(n_rows):
        new_image[j, :] = input_image[j, kmax_index[j, :]] + b[j]
    return new_image, kmax_index

def k_max_pooling_backward(grad, input_images, kmax_index, back_linear=True):
    """Route the pooled error back to the positions chosen by k-max pooling.

    Args:
        grad: error w.r.t. the pooled output, indexed (map, row, k).
        input_images: 3-D pre-pooling images (map, row, col).
        kmax_index: column indexes selected by pooling, indexed (map, row, k).
        back_linear: when False, each map's routed error is multiplied by
            ``x * (1 - x)`` with ``x = input_images[map]`` (the sigmoid
            derivative if those are sigmoid outputs -- confirm).

    Returns:
        ``(back_grad, grad_b)``: ``back_grad`` has the shape of
        ``input_images`` and is zero except at the selected columns.
        ``grad_b[m, r]`` is the sum of ``grad[m, r, :]``.
    """
    n_maps, n_rows, _ = input_images.shape
    back_grad = np.zeros(input_images.shape, dtype=FLOAT_TYPE)
    grad_b = np.zeros((n_maps, n_rows), dtype=FLOAT_TYPE)
    for m in range(n_maps):
        for r in range(n_rows):
            g_row = grad[m, r, :]
            back_grad[m, r, kmax_index[m, r, :]] = g_row
            grad_b[m, r] = np.sum(g_row)
        if back_linear:
            continue
        act = input_images[m, :, :]
        back_grad[m, :, :] = back_grad[m, :, :] * act * (1 - act)
    return back_grad, grad_b

def folding_image(input_image):
    """Fold a 2-D image by summing each pair of adjacent rows.

    Row ``r`` of the result is ``input_image[2r] + input_image[2r + 1]``, so
    the output has half as many rows and the same number of columns.

    Args:
        input_image: 2-D array (rows, cols) with an even number of rows.

    Returns:
        FLOAT_TYPE array of shape (rows // 2, cols).

    Raises:
        ValueError: if the number of rows is odd.
    """
    n_rows = input_image.shape[0]
    if n_rows % 2:
        raise ValueError("folding needs an even number of rows, got %d" % n_rows)
    # Floor division: on Python 3 ``/`` yields a float, which numpy rejects
    # both as an array dimension and as an index.
    output_image = np.empty((n_rows // 2, input_image.shape[1]), dtype=FLOAT_TYPE)
    # Even rows + odd rows, summed in the input's dtype then stored as FLOAT_TYPE.
    output_image[...] = input_image[0::2] + input_image[1::2]
    return output_image

# Module-level hooks: the activation in use, and the 2-D convolution routine
# (scipy.signal.convolve2d) looked up at call time by the forward layer.
nonlinear_func, convolve2d = sigmoid, c2d